{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load IPython's autoreload extension; `%autoreload 2` asks it to re-import\n",
    "# changed modules before each cell runs (see the IPython autoreload docs).\n",
    "%load_ext autoreload\n",
    "\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# Process-level settings. This cell runs before the cell that imports jax so\n",
    "# the variables are already in place when the libraries read them.\n",
    "\n",
    "# Extend (rather than overwrite) any XLA_FLAGS already present. The membership\n",
    "# check keeps the cell idempotent: re-running it must not append the same flag\n",
    "# a second time.\n",
    "xla_flags = os.environ.get(\"XLA_FLAGS\", \"\")\n",
    "if \"--xla_gpu_triton_gemm_any=True\" not in xla_flags.split():\n",
    "  xla_flags += \" --xla_gpu_triton_gemm_any=True\"\n",
    "os.environ[\"XLA_FLAGS\"] = xla_flags\n",
    "\n",
    "# MuJoCo rendering backend; \"egl\" is presumably chosen for headless GPU\n",
    "# rendering -- confirm for your machine.\n",
    "os.environ[\"MUJOCO_GL\"] = \"egl\"\n",
    "# Ask JAX not to preallocate device memory up front (see JAX GPU memory docs).\n",
    "os.environ[\"XLA_PYTHON_CLIENT_PREALLOCATE\"] = \"false\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import functools\n",
    "\n",
    "import jax\n",
    "import jax.numpy as jp\n",
    "import mediapy as media\n",
    "from brax.training.agents.ppo import networks as ppo_networks\n",
    "from brax.training.agents.ppo import train as ppo\n",
    "from etils import epath\n",
    "\n",
    "from mujoco_playground import BraxEnvWrapper, locomotion\n",
    "\n",
    "# Enable persistent compilation cache.\n",
    "# Cache entries are written under /tmp/jax_cache. The two thresholds below are\n",
    "# set to -1 bytes and 0 seconds, presumably so that every compiled program is\n",
    "# cached regardless of its size or compile time (see the JAX persistent\n",
    "# compilation cache docs).\n",
    "jax.config.update(\"jax_compilation_cache_dir\", \"/tmp/jax_cache\")\n",
    "jax.config.update(\"jax_persistent_cache_min_entry_size_bytes\", -1)\n",
    "jax.config.update(\"jax_persistent_cache_min_compile_time_secs\", 0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the environment from its default config. `env` is wrapped and handed\n",
    "# to the PPO restore call further down; `env_name` and `env_cfg` are reused to\n",
    "# load the evaluation environment.\n",
    "#\n",
    "# No jitted reset/step are created here: the ones used by the rollout are\n",
    "# built from `eval_env` in a later cell, and defining them twice only meant\n",
    "# the first pair was silently shadowed without ever being called.\n",
    "env_name = \"Go1JoystickFlatTerrain\"\n",
    "env_cfg = locomotion.get_default_config(env_name)\n",
    "env = locomotion.load(env_name, config=env_cfg)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Directory holding the checkpoints of the trained policy. Set the CKPT_PATH\n",
    "# environment variable to point at your own training run; the fallback is the\n",
    "# path this notebook was originally written against.\n",
    "CKPT_PATH = epath.Path(\n",
    "    os.environ.get(\n",
    "        \"CKPT_PATH\",\n",
    "        \"/home/kevin/dev/mjx_control/learning/checkpoints/Go1JoystickFlatTerrain-20241112-170242\",\n",
    "    )\n",
    ")\n",
    "# Fail early, and say which path was tried, instead of a bare AssertionError.\n",
    "assert CKPT_PATH.exists(), f\"Checkpoint directory not found: {CKPT_PATH}\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Collect the checkpoint entries, skipping .json files stored next to them.\n",
    "latest_ckpts = [\n",
    "    ckpt for ckpt in CKPT_PATH.glob(\"*\") if not ckpt.name.endswith(\".json\")\n",
    "]\n",
    "# Check for emptiness only after filtering: a directory that holds nothing but\n",
    "# .json files would otherwise pass the check and fail below with an IndexError.\n",
    "if not latest_ckpts:\n",
    "  raise ValueError(\"No checkpoints found\")\n",
    "# Entry names are parsed as integers so the order is numeric, not\n",
    "# lexicographic; a non-numeric entry name raises ValueError here.\n",
    "latest_ckpts.sort(key=lambda x: int(x.name))\n",
    "print(f\"Found {len(latest_ckpts)} checkpoints\")\n",
    "latest_ckpt = latest_ckpts[-1]\n",
    "restore_checkpoint_path = latest_ckpt\n",
    "print(f\"Restoring from {restore_checkpoint_path}\")\n",
    "\n",
    "# Policy network factory. NOTE(review): the hidden layer sizes presumably have\n",
    "# to match the architecture the checkpoint was trained with -- confirm.\n",
    "make_networks_factory = functools.partial(\n",
    "    ppo_networks.make_ppo_networks,\n",
    "    policy_hidden_layer_sizes=(128, 128, 128, 128),\n",
    ")\n",
    "\n",
    "# num_timesteps=0 requests no training; the call is used to restore the\n",
    "# checkpoint and obtain the inference-function factory and parameters.\n",
    "train_fn = functools.partial(\n",
    "    ppo.train,\n",
    "    num_timesteps=0,\n",
    "    episode_length=1000,\n",
    "    normalize_observations=True,\n",
    "    restore_checkpoint_path=restore_checkpoint_path,\n",
    "    network_factory=make_networks_factory,\n",
    "    num_envs=2,\n",
    ")\n",
    "# The third return value is discarded.\n",
    "make_inference_fn, params, _ = train_fn(environment=BraxEnvWrapper(env))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the policy function from the restored parameters. deterministic=True\n",
    "# is passed to the factory -- presumably so actions are not sampled; confirm\n",
    "# against the brax PPO networks documentation.\n",
    "inference_fn = make_inference_fn(params, deterministic=True)\n",
    "# JIT-compiled version, used by the rollout loop.\n",
    "jit_inference_fn = jax.jit(inference_fn)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load a separate instance of the same environment (same name and config) for\n",
    "# evaluation, and JIT-compile its reset and step functions. The rollout and\n",
    "# rendering cells below use `eval_env`, `jit_reset` and `jit_step`.\n",
    "eval_env = locomotion.load(env_name, config=env_cfg)\n",
    "jit_reset = jax.jit(eval_env.reset)\n",
    "jit_step = jax.jit(eval_env.step)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Roll out the policy for a single episode while holding the command fixed.\n",
    "# NOTE(review): the three components are presumably the joystick command\n",
    "# (forward velocity, lateral velocity, yaw rate) -- confirm against the env.\n",
    "cmd = jp.array([0.3, 0, 0.1])\n",
    "# Fixed seed so the rollout is reproducible.\n",
    "rng = jax.random.PRNGKey(12345)\n",
    "rollout = []\n",
    "for _ in range(1):\n",
    "  rng, reset_rng = jax.random.split(rng)\n",
    "  state = jit_reset(reset_rng)\n",
    "  # Overwrite whatever command reset() stored with the fixed one.\n",
    "  state.info[\"command\"] = cmd\n",
    "  # NOTE(review): the loop always runs env_cfg.episode_length steps and never\n",
    "  # inspects the state for termination -- confirm that is intended.\n",
    "  for i in range(env_cfg.episode_length):\n",
    "    act_rng, rng = jax.random.split(rng)\n",
    "    ctrl, _ = jit_inference_fn(state.obs, act_rng)\n",
    "    state = jit_step(state, ctrl)\n",
    "    # Re-apply the fixed command after every step, before the state is stored\n",
    "    # and fed back into the policy.\n",
    "    state.info[\"command\"] = cmd\n",
    "    rollout.append(state)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Render every `render_every`-th state of the rollout and show it as a video.\n",
    "render_every = 2\n",
    "# Frame rate is scaled down by the same factor as the subsampling; assuming\n",
    "# eval_env.dt is the step duration in seconds, playback is then real time.\n",
    "fps = 1.0 / eval_env.dt / render_every\n",
    "traj = rollout[::render_every]\n",
    "\n",
    "# \"track\" is the name of the camera requested from the environment's renderer.\n",
    "frames = eval_env.render(\n",
    "    traj,\n",
    "    camera=\"track\",\n",
    "    height=480,\n",
    "    width=640,\n",
    ")\n",
    "media.show_video(frames, fps=fps, loop=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.15"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
